import os
from multiprocessing import Pool

import multiprocessing
from torch.utils.data import Dataset
import openpyxl
from data_generator import Generator
from load import get_gnn_inputs
from models import GNN_multiclass, GNN_multiclass_second_period
import time
import argparse
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from torch import Tensor
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from get_infor_from_first_step import get_second_period_labels_single, test_first_get_second_period_labels_single, \
    in_sample_test_first_get_second_period_labels_single, imbalanced_test_first_get_second_period_labels_single
from losses import compute_loss_multiclass, compute_accuracy_multiclass, compute_accuracy_spectral
from load_local_refinement import get_gnn_inputs_local_refinement
from controlsnr import find_a_given_snr
import sys
import os
import time
import numpy as np
from scipy.sparse.csgraph import laplacian
from scipy.linalg import eigh
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from scipy.sparse import issparse
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# Column layouts for the per-iteration training log:
# iteration index, loss, accuracy and elapsed seconds.
template_header = "{:<6} {:<10} {:<10} {:<10}"
template_row = "{:<6d} {:<10.4f} {:<10.2f} {:<10.2f}"

# Use CUDA when it is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level caches (not read or written by the functions in this section).
cached_graphs = list()
cached_labels = list()


## Single-batch training step for the first-period GNN.
def train_batch_first_period(gnn, optimizer, batch, n_classes, iter, device, args):
    """Run one training step of the first-period GNN on a single batch.

    The batch is converted with ``get_gnn_inputs``, fed through the model and
    scored with the project's permutation-aware ``compute_loss_multiclass``.
    Gradients are clipped to ``args.clip_grad_norm`` before
    ``optimizer.step()``. A header line and a result line are printed on
    every call.

    Args:
        gnn: model, called as ``gnn(WW, x)``.
        optimizer: optimizer over ``gnn``'s parameters.
        batch: mapping with tensors under ``'adj'`` and ``'labels'``
            (assumed ``(B, N, N)`` and ``(B, N)`` -- TODO confirm).
        n_classes: number of communities.
        iter: batch index, used only in the printed log row. It shadows the
            builtin but is kept because callers pass it by keyword.
        device: device the model's parameters live on.
        args: namespace read for ``J`` and ``clip_grad_norm``.

    Returns:
        ``(loss_value, acc)``: the batch loss as a Python float and the
        accuracy returned by ``compute_accuracy_multiclass``.
    """
    gnn.train()

    # get_gnn_inputs consumes NumPy arrays, so read the adjacency from the
    # batch directly instead of copying it to the device and straight back.
    adj = batch['adj'].cpu().numpy()
    labels = batch['labels'].to(device)

    start = time.time()

    WW, x = get_gnn_inputs(adj, args.J)
    WW = WW.clone().detach().to(device=device, dtype=torch.float32)
    x = x.clone().detach().to(device=device, dtype=torch.float32)

    optimizer.zero_grad(set_to_none=True)
    pred = gnn(WW, x)
    # The loss function handles the batch dimension itself; no reshape needed.
    loss = compute_loss_multiclass(pred, labels, n_classes)
    loss.backward()

    nn.utils.clip_grad_norm_(gnn.parameters(), args.clip_grad_norm)
    optimizer.step()

    # Accuracy is only a metric: compute it without recording an autograd graph.
    with torch.no_grad():
        acc, _ = compute_accuracy_multiclass(pred.detach(), labels, n_classes)

    elapsed_time = time.time() - start
    loss_value = loss.item()

    print(template_header.format('iter', 'avg loss', 'avg acc', 'elapsed'))
    print(template_row.format(iter, loss_value, acc, elapsed_time))

    return loss_value, acc

def from_scores_to_labels_multiclass_batch(pred):
    """Turn per-node class scores into hard labels.

    Takes the arg-max over axis 2 (the class axis of a ``(B, N, C)`` score
    array) and returns the result as an integer array of shape ``(B, N)``.
    """
    return np.asarray(pred).argmax(axis=2).astype(int)

def _mean_partition_scores(true_labels, pred_labels):
    """Return ``(mean_ari, mean_nmi)`` over the rows (graphs) of a batch.

    Community ids are only meaningful up to a relabelling within each graph,
    so the scores are computed graph by graph. Scoring the flattened batch
    would penalise independent per-graph label permutations.
    """
    pairs = list(zip(true_labels, pred_labels))
    ari = float(np.mean([adjusted_rand_score(t, p) for t, p in pairs]))
    nmi = float(np.mean([normalized_mutual_info_score(t, p) for t, p in pairs]))
    return ari, nmi


def evaluate_on_loader(gnn, val_loader, n_classes, args, device):
    """Evaluate ``gnn`` on every batch of ``val_loader`` without training it.

    The model is put in ``eval()`` mode for the pass, so dropout is off and
    BatchNorm-style layers do not update their running statistics. Its
    previous train/eval mode is restored afterwards.

    Args:
        gnn: model, called as ``gnn(WW, x)``.
        val_loader: iterable of batches with ``'adj'`` and ``'labels'`` tensors;
            must support ``len()``.
        n_classes: number of communities.
        args: namespace read for ``J``.
        device: device the model's parameters live on.

    Returns:
        ``(avg_loss, avg_acc, avg_ari, avg_nmi)``, each averaged over batches.
        ARI/NMI of a batch are the mean of the per-graph scores.
        Note the order: ARI comes before NMI.

    Raises:
        ZeroDivisionError: if ``len(val_loader) == 0``.
    """
    was_training = gnn.training
    gnn.eval()
    total_loss, total_acc, total_nmi, total_ari = 0, 0, 0, 0
    try:
        with torch.no_grad():
            for batch in val_loader:
                labels = batch['labels'].to(device)

                # get_gnn_inputs consumes NumPy arrays; skip the device round trip.
                WW, x = get_gnn_inputs(batch['adj'].cpu().numpy(), args.J)
                WW = WW.clone().detach().to(device=device, dtype=torch.float32)
                x = x.clone().detach().to(device=device, dtype=torch.float32)

                pred = gnn(WW, x)
                loss = compute_loss_multiclass(pred, labels, n_classes)
                acc, _ = compute_accuracy_multiclass(pred, labels, n_classes)

                pred_labels = from_scores_to_labels_multiclass_batch(pred.cpu().numpy())
                # Align label layout with the (B, N) predictions, one row per graph.
                true_labels = labels.cpu().numpy().reshape(pred_labels.shape)
                ari, nmi = _mean_partition_scores(true_labels, pred_labels)

                total_loss += loss.item()
                total_acc += acc
                total_nmi += nmi
                total_ari += ari
    finally:
        gnn.train(was_training)

    n_batches = len(val_loader)
    avg_loss = total_loss / n_batches
    avg_acc = total_acc / n_batches
    avg_nmi = total_nmi / n_batches
    avg_ari = total_ari / n_batches
    return avg_loss, avg_acc, avg_ari, avg_nmi

def _snapshot_to_disk(gnn, path):
    """Pickle the whole module to ``path`` from the CPU and return it on ``device``."""
    # nn.Module.cpu()/.to() move parameters in place, so the optimizer created
    # by the caller still references the same Parameter objects afterwards.
    torch.save(gnn.cpu(), path)
    return gnn.to(device)


def train_first_period_with_early_stopping(
    gnn,
    train_loader,
    val_loader,
    n_classes,
    args,
    epochs: int = 100,
    patience: int = 3,
    save_path: str = 'best_model.pt',
    filename: str = "filename_first",
    acc_eps: float = 1e-8,
    loss_eps: float = 1e-12,
):
    """Train the first-period GNN with NMI-based early stopping.

    Each epoch runs one Adamax step (``lr=args.lr``) per batch of
    ``train_loader`` and then evaluates on ``val_loader``. After every epoch
    the whole module is pickled (as a CPU model) to ``filename``. Whenever
    validation improves, it is also pickled to ``save_path``.

    Improvement rule:
    - A validation NMI higher by more than ``acc_eps`` wins.
    - If NMI is tied within ``acc_eps``, a validation loss lower by more than
      ``loss_eps`` wins.

    Training stops after ``patience`` consecutive epochs without improvement.

    Args:
        gnn: model to train (modified in place).
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches (must support ``len()``).
        n_classes: number of communities.
        args: namespace read for ``lr``, ``J`` and ``clip_grad_norm``.
        epochs: maximum number of epochs.
        patience: epochs without improvement before stopping.
        save_path: where the best model is pickled.
        filename: where the latest per-epoch snapshot is pickled.
        acc_eps: tie tolerance for validation NMI. The name is kept for
            backward compatibility.
        loss_eps: minimum loss decrease that counts as an improvement on a tie.

    Returns:
        ``(loss_lst, acc_lst)``: per-batch training losses and accuracies over
        all epochs run. The in-memory ``gnn`` ends as the last-epoch model;
        the best model is only available from ``save_path``.
    """
    optimizer = torch.optim.Adamax(gnn.parameters(), lr=args.lr)

    loss_lst, acc_lst = [], []
    best_val_nmi = -1.0
    best_val_loss = float("inf")
    patience_counter = 0

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        gnn.train()

        for iter_idx, batch in enumerate(tqdm(train_loader)):
            loss, acc = train_batch_first_period(
                gnn=gnn,
                optimizer=optimizer,
                batch=batch,
                n_classes=n_classes,
                iter=iter_idx,
                device=device,
                args=args,
            )
            loss_lst.append(loss)
            acc_lst.append(acc)

        # evaluate_on_loader returns (loss, acc, ARI, NMI) -- in that order.
        val_loss, val_acc, val_ari, val_nmi = evaluate_on_loader(
            gnn, val_loader, n_classes, args, device=device
        )

        print(f"Validation Loss: {val_loss:.6f}, NMI: {val_nmi:.6f},  Accuracy: {val_acc:.6f}")

        # Snapshot of the latest model, written every epoch.
        gnn = _snapshot_to_disk(gnn, filename)

        if val_nmi > best_val_nmi + acc_eps:
            reason = "val_nmi improved"
        elif abs(val_nmi - best_val_nmi) <= acc_eps and val_loss < best_val_loss - loss_eps:
            reason = "val_nmi tie, val_loss improved"
        else:
            reason = None

        if reason is not None:
            best_val_nmi = val_nmi
            best_val_loss = val_loss
            patience_counter = 0
            gnn = _snapshot_to_disk(gnn, save_path)
            print(f"New best model saved ({reason}). best_nmi={best_val_nmi:.6f}, best_loss={best_val_loss:.6f}")
        else:
            patience_counter += 1
            print(f"No improvement ({patience_counter}/{patience}). "
                  f"best_nmi={best_val_nmi:.6f}, best_loss={best_val_loss:.6f}")

        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

        # Release cached CUDA blocks once per epoch (no-op if CUDA is unused).
        torch.cuda.empty_cache()

    return loss_lst, acc_lst
